import itertools

import cv2
import numpy as np
import pptk
from sklearn.cluster import KMeans, MeanShift
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
import matplotlib.pyplot as plt

# Colour-match thresholds on the stddev-normalised squared Lab distance:
# the strict one (t1) selects seed points, the loose one (t2) grows the final selection.
col_match_t1 = 1.5 ** 2
col_match_t2 = 4 ** 2
# Threshold on the *squared* distance from a cluster centre for points that belong to one cube.
cube_dist = 60e-3 ** 2
# Half of the MeanShift bandwidth used to split colour-matched points into separate objects.
object_bandwidth = 30e-3
# Threshold on the squared distance from the grasper centre.
# NOTE(review): unlike cube_dist this value is not squared — confirm the intended radius.
grasper_dist = 60e-3

# Accepted depth (z) window used by reduce_pts; presumably metres.
max_limit = 0.7
min_limit = 0.3


def reduce_pts(rgb, xyz):
    """ Drops invalid points and points outside the accepted depth window.

    A pixel is kept when its x coordinate is not NaN and its z coordinate is non-zero and lies
    strictly between min_limit and max_limit.

    :param rgb: image indexed by the same first two (row, column) dimensions as xyz
    :param xyz: H x W x 3 organised point cloud
    :return: (rgb, xyz) of the kept pixels, flattened along the first two dimensions
    """
    depth = xyz[:, :, 2]
    keep = ~np.isnan(xyz[:, :, 0])
    keep &= depth > min_limit
    keep &= depth < max_limit
    keep &= depth != 0
    rows, cols = np.nonzero(keep)
    return rgb[rows, cols], xyz[rows, cols]


def find_pts(rgb, xyz, color, stddev, debug=False):
    """ Splits the points of the given color into cube-sized objects

    Points whose stddev-normalized squared Lab distance to `color` is below col_match_t1 are clustered in
    space with MeanShift. For each cluster with at least 10 such points, every point whose color distance
    is below col_match_t2 and whose squared distance to the cluster mean is below cube_dist is selected;
    selections with more than 100 points are returned.

    :param rgb: n x 3 array of RGB color vectors (presumably uint8 0..255 as fed to cv2.cvtColor — confirm)
    :param xyz: n x 3 array of XYZ point cloud data
    :param color: 3d vector of color in the Lab colorspace
    :param stddev: 3d vector of stddevs of color in the Lab colorspace
    :param debug: debug flag for displaying pointcloud (one viewer per candidate cluster)
    :return: list of m x 3 arrays of XYZ point cloud data, one per detected cube (may be empty)
    """
    target_lab = np.array(color, dtype=np.float32)
    pixels_lab = cv2.cvtColor(rgb[:, np.newaxis, :], cv2.COLOR_RGB2LAB)
    pixels_lab = pixels_lab.astype(np.float32)[:, 0, :]
    inv_std = np.divide(1, stddev).astype(np.float32)
    weighted_diff = (pixels_lab - target_lab) * inv_std
    color_dist = np.sum(np.square(weighted_diff), axis=-1)

    seed_idx = np.flatnonzero(color_dist < col_match_t1)

    clusters = []
    if seed_idx.size == 0:
        return clusters

    shift = MeanShift(bandwidth=2 * object_bandwidth, bin_seeding=True)
    shift.fit(xyz[seed_idx])
    cluster_ids = shift.labels_
    near_color = color_dist < col_match_t2

    for cluster_id in np.unique(cluster_ids):
        members = seed_idx[cluster_ids == cluster_id]
        # Ignore tiny clusters of seed points
        if members.size < 10:
            continue

        center = np.average(xyz[members], axis=0)
        sq_dist = np.sum((xyz - center) ** 2, axis=-1)
        mask = near_color & (sq_dist < cube_dist)

        if debug:
            # selected points green, the rest blue
            green = np.repeat([[0, 1, 0]], len(rgb), axis=0)
            blue = np.repeat([[0, 0, 1]], len(rgb), axis=0)
            highlight = np.where(mask[:, np.newaxis], green, blue)
            viewer = pptk.viewer(xyz)
            viewer.attributes(highlight, rgb / 255)
            viewer.wait()
            viewer.close()

        selected = np.flatnonzero(mask)
        if selected.size > 100:
            clusters.append(xyz[selected])

    return clusters


def find_pts_grasper(rgb, xyz, color, stddev, debug=False):
    """ Returns XYZ points corresponding to the grasper of given color

    Points whose stddev-normalized squared Lab distance to `color` is below col_match_t1 define the
    grasper centroid. The result is every point whose color distance is below the looser col_match_t2
    and whose squared distance to that centroid is below grasper_dist. Unlike find_pts no spatial
    clustering is done: a single grasper is assumed.

    :param rgb: n x 3 array of RGB color vectors (presumably uint8 0..255 as fed to cv2.cvtColor — confirm)
    :param xyz: n x 3 array of XYZ point cloud data
    :param color: 3d vector of color in the Lab colorspace
    :param stddev: 3d vector of stddevs of color in the Lab colorspace
    :param debug: debug flag for displaying pointcloud
    :return: list with a single m x 3 array of XYZ points of the grasper (m may be 0),
             or an empty list if no point matches the color strictly
    """
    color = np.array(color, dtype=np.float32)
    lab = cv2.cvtColor(rgb[:, np.newaxis, :], cv2.COLOR_RGB2LAB).astype(np.float32)[:, 0, :]
    # stddev is turned into its reciprocal so the distance below is normalized per Lab channel
    stddev = np.divide(1, stddev).astype(np.float32)
    col_dist = np.sum(np.square((lab - color) * stddev), axis=-1)

    l_col_match = np.asarray(col_dist < col_match_t1).nonzero()[0]

    pts_list = []

    if len(l_col_match) == 0:
        return pts_list

    col_center = np.average(xyz[l_col_match], axis=0)
    # squared distance to the centroid
    # NOTE(review): grasper_dist is compared against a *squared* distance but, unlike cube_dist,
    # is not itself squared — confirm the intended radius.
    spatial_dist = np.sum((xyz - col_center) ** 2, axis=-1)
    c_final = np.logical_and.reduce([col_dist < col_match_t2, spatial_dist < grasper_dist])

    if debug:
        # selected points green, the rest blue
        viewer_cols = np.where(np.repeat(c_final[:, np.newaxis], 3, axis=1),
                               np.repeat([[0, 1, 0]], len(rgb), axis=0),
                               np.repeat([[0, 0, 1]], len(rgb), axis=0))
        viewer = pptk.viewer(xyz)
        viewer.attributes(viewer_cols, rgb / 255)
        viewer.wait()
        viewer.close()

    l_final = np.asarray(c_final).nonzero()[0]
    pts_list.append(xyz[l_final])

    return pts_list


def get_plane_3pts(p):
    """ Returns the planes passing through triples of points

    :param p: array of shape 3 x n x 3, where n is number of planes, vectors are in the last dim
    :return: n x 4 array with plane spec such that each row corresponds to a, b, c, d in a*x + b*y + c*z + d = 0.
             The normal (a, b, c) is the unnormalized cross product (p[2] - p[0]) x (p[1] - p[0]); it is zero
             when the three points are collinear or coincide.
    """
    v1 = p[2] - p[0]
    v2 = p[1] - p[0]
    cp = np.cross(v1, v2)
    d = -np.expand_dims(np.sum(cp * p[2], axis=-1), axis=-1)
    return np.concatenate([cp, d], axis=-1)


def find_first_plane(xyz_h, prev_corner_list=None):
    """ Performs RANSAC to find the plane supported by the most points

    64 candidate planes are tested. They come from random point triples (drawn with replacement), except
    that each entry of prev_corner_list replaces one random triple by the plane through its first three
    corners, so previously detected faces are re-tested (with more than 64 entries only those are used).
    A candidate is scored by the number of points closer than 1e-3 to it; the inliers of the winner are
    then collected with a looser 2e-3 tolerance (in the units of xyz_h, presumably metres).

    :param xyz_h: n x 4 array with 3d vectors in homogeneous coordinates
    :param prev_corner_list: list of arrays of dim 3+ x 3 of previously detected corners, or None
    :return: bestP - 4 x 1 array with plane spec such that a*x + b*y + c*z + d = 0
    :return: r_xyz_h - k x 4 array with the points NOT on the plane (outliers) in homogeneous coords
    :return: t_xyz_h - m x 4 array with the points on the plane (inliers) in homogeneous coords
    :return: plane_conf - float 0..1 for confidence of the predicted plane
             (1 - mean squared inlier distance / tolerance**2; 0 when there are no inliers)
    """
    num_candidates = 64
    n_pts = xyz_h.shape[0]
    if prev_corner_list is None:
        prev_corner_list = []

    # Random triples fill the candidate slots not taken by previous detections; never request a
    # negative number of samples when there are more previous detections than slots.
    num_random = max(num_candidates - len(prev_corner_list), 0)
    row_i = np.random.choice(n_pts, [3, num_random])
    pts_for_plane = xyz_h[row_i, :3]
    for prev_corners in prev_corner_list:
        prev_pts = np.reshape(np.array(prev_corners[:3], dtype=np.float32), [3, 1, 3])
        pts_for_plane = np.concatenate([pts_for_plane, prev_pts], axis=1)

    planes = get_plane_3pts(pts_for_plane)
    # |P . x| / |(a, b, c)| is the point-plane distance; comparing against tol * |(a, b, c)| avoids
    # the division. Degenerate candidates (zero normal) never satisfy the strict '<' and get no support.
    residuals = np.abs(np.matmul(xyz_h, planes.T))
    normal_norms = np.sqrt(np.sum(planes[:, :3] ** 2, axis=-1))
    support = np.sum(residuals < 1e-3 * normal_norms, axis=0)
    bestP = planes[np.argmax(support)]

    # Collect the inliers of the winning plane with a looser tolerance
    best_norm_sq = np.sum(bestP[:3] ** 2)
    on_plane = np.abs(np.matmul(xyz_h, bestP)) < 2e-3 * np.sqrt(best_norm_sq)
    t_xyz_h = xyz_h[on_plane.nonzero()[0], :]
    r_xyz_h = xyz_h[(~on_plane).nonzero()[0], :]

    if t_xyz_h.shape[0] == 0:
        # No inliers (e.g. only degenerate candidates): avoid the mean of an empty array / 0 division
        plane_conf = 0.0
    else:
        plane_conf = 1 - np.mean(np.square(np.matmul(t_xyz_h, bestP))) / (2e-3 ** 2 * best_norm_sq)

    return np.expand_dims(bestP, axis=-1), r_xyz_h, t_xyz_h, plane_conf


def find_hull_direction(xy, num_repeats=10, debug=False):
    """ Return the direction vector for the new coordinate system based on the longest lines in the convex hull
        The calculation is based on a circular mean of angles reduced to the first quadrant:
        https://en.wikipedia.org/wiki/Mean_of_circular_quantities

        To asses the accuracy the circular variance R (http://www.fiserlab.org/manuals/procheck/manual/man_cv.html) is
        used to output arbitrarily defined confidence 1 - Var(alpha)**2

        The mean is taken over num_repeats hulls of the points with 1/num_repeats sampling.
        This leads to O(n * log(n/num_repeats)) complexity

    :param xy: n x 2 array with 2d poitns
    :param num_repeats: number of hull samplings
    :return: alpha_final: 2d unit vector of the extracted direction, conf the confidence value 0..1
    """
    if debug:
        plt.scatter(xy[:, 0], xy[:, 1])

    s = 0
    c = 0
    ns = 0

    for _ in range(num_repeats):
        # we sample random points a construct a hull
        hull_pts = cv2.convexHull(xy[np.random.choice(xy.shape[0], xy.shape[0] // num_repeats)].astype(np.float32))
        if debug:
            plt.scatter(hull_pts[:, 0, 0], hull_pts[:, 0, 1])

        ns += hull_pts.shape[0]
        lines = hull_pts[:, 0, :] - np.roll(hull_pts[:, 0, :], -1, axis=0)
        alphas = np.arctan2(lines[:, 0], lines[:, 1])
        # we force angles to occupy the first quadrant to resolve the different directions on different sides
        alphas = np.mod(alphas, np.pi/4)
        # to properly calculate the circular mean we have to scale the angles to the full range 0..2*Pi
        alphas *= 4
        s += np.sum(np.sin(alphas))
        c += np.sum(np.cos(alphas))
    if debug:
        plt.show()
    # after calculation we scale the angle back e.g. divide by 4
    alpha_final = np.arctan2(c / ns, s / ns) / 4
    final_direction = np.array([np.cos(alpha_final), np.sin(alpha_final)])

    circular_variance = 1 - np.sqrt(s**2 + c**2) / ns
    conf = 1 - circular_variance**2

    return final_direction, conf


# Known object sizes (presumably metres): side of the cube and the three sides of the cuboid.
cube_side = 24.0e-3
cuboid_sides = [12.0e-3, 24.0e-3, 50.0e-3]


def get_cuboid_dims(x, y):
    """ Returns cuboid dimensions by matching the measured face extent against the known object sizes.

        The extent of the face along each axis is the 5th..95th percentile range of the coordinates. For every
        candidate face (the cube and the 6 orientations of the cuboid) a normalized squared error is computed:

            e = ((x_dist - x_dim) / x_dim)^2 + ((y_dist - y_dim) / y_dim)^2

        The candidate with the smallest e is chosen and its confidence is max(1 - sqrt(e), 0).

    :param x: Array of x-coordinates in the oriented plane of one of the faces
    :param y: Array of y-coordinates in the oriented plane of one of the faces
    :return: x_m, y_m, x_dim, y_dim, z_dim, obj_type, dims_conf
             x_m, y_m: center of the frontal face (float32)
             x_dim, y_dim, z_dim: dimensions of the object; z_dim is the side not seen in the face
             obj_type: type string 'cube' or 'cuboid'
             dims_conf: confidence value 0..1 of the match
    """

    # We use percentiles for a more robust approximation
    x_min, x_max = np.percentile(x, [5, 95])
    y_min, y_max = np.percentile(y, [5, 95])

    # Measured face extent
    x_dist = x_max - x_min
    y_dist = y_max - y_min

    # Face center
    x_m = np.divide(x_min + x_max, 2, dtype=np.float32)
    y_m = np.divide(y_min + y_max, 2, dtype=np.float32)

    # Candidate 0 is the cube; candidates 1..6 are the cuboid orientations (x side, y side, depth)
    cuboid_sides_perm = list(itertools.permutations(cuboid_sides))
    candidates = [(cube_side, cube_side)] + [(side_x, side_y) for side_x, side_y, _ in cuboid_sides_perm]

    # Maximum confidence means minimum normalized squared error
    errors = [((x_dist - dim_x) / dim_x) ** 2 + ((y_dist - dim_y) / dim_y) ** 2 for dim_x, dim_y in candidates]

    idx = np.argmin(errors)
    conf = max(1 - np.sqrt(errors[idx]), 0)

    if idx == 0:
        return x_m, y_m, cube_side, cube_side, cube_side, 'cube', conf
    return (x_m, y_m, *cuboid_sides_perm[idx - 1], 'cuboid', conf)


def project_cuboid_onto_plane(P, pts, plane_conf, color_name):
    """ Fits a cube/cuboid to points given the plane of one of its faces and returns the object dict

    The points are expressed in a 2d coordinate system of the plane, rotated to align with the dominant
    hull direction (find_hull_direction); the face extent is matched to the known object sizes
    (get_cuboid_dims). The four face corners are then offset along the plane normal (oriented so that its
    z component is non-negative) by the remaining dimension to obtain the other four corners.

    :param P: 4 x 1 array with plane spec such that a*x + b*y + c*z + d = 0 (as returned by find_first_plane)
    :param pts: n x 4 array of 3d points in homogeneous coords; they are projected orthogonally onto the plane,
                so they need not lie in it
    :param plane_conf: confidence of the plane fit, 0..1
    :param color_name: string with the name of the color
    :return: dict with 'type' ('cube' or 'cuboid'), 'corners' (8 x 3 array of XYZ corners: four on the plane,
             then the same four offset along the normal), 'color' (color_name) and 'conf' (geometric mean of
             the plane, direction and dimension confidences)
    """

    # Project the mean of the points onto the plane; it is the origin of the in-plane coordinates.
    # With P of shape 4 x 1, xyz_c ends up with shape 1 x 3, which broadcasts below.
    xyz_c = np.mean(pts, axis=0)
    dist = np.matmul(xyz_c, P)
    Pn = np.divide(P[:3], np.sum(np.power(P[:3], 2), axis=0))
    xyz_c = xyz_c[:3] - dist * Pn.T

    # Get two perpendicular unit vectors in the plane as a base for a new coord system.
    # (c, 0, -a) is orthogonal to the normal (a, b, c) but vanishes when the normal is parallel to the
    # y axis; the x axis lies in the plane in that case.
    p1 = np.array([P[2, 0], 0, -P[0, 0]], dtype=np.float32)
    if not np.any(p1):
        p1 = np.array([1, 0, 0], dtype=np.float32)
    p1 = p1 / np.linalg.norm(p1)
    p2 = np.cross(p1, P[:3, 0].T)
    p2 = p2 / np.linalg.norm(p2)

    # Get x,y coords in the new coord system
    x = np.sum(p1 * (pts[:, :3] - xyz_c), axis=1)
    y = np.sum(p2 * (pts[:, :3] - xyz_c), axis=1)

    xy = np.column_stack([x, y])

    # Find the direction which aligns best with the edges of the hull of the points in the new system
    direction, direction_conf = find_hull_direction(xy)

    # Use the direction to find another coordinate system aligned with the cube
    pp1 = np.matmul(np.column_stack([p1, p2]), direction)
    pp1 = pp1 / np.linalg.norm(pp1)
    pp2 = np.cross(pp1, P[:3, 0].T)
    pp2 = pp2 / np.linalg.norm(pp2)

    # Transform points to the final 2d coordinate system
    x = np.sum(pp1 * (pts[:, :3] - xyz_c), axis=1)
    y = np.sum(pp2 * (pts[:, :3] - xyz_c), axis=1)

    # Get cuboid dimensions
    x_m, y_m, x_dim, y_dim, z_dim, obj_type, dim_conf = get_cuboid_dims(x, y)

    xy_corners = np.array([[x_m - x_dim / 2, y_m - y_dim / 2],
                           [x_m - x_dim / 2, y_m + y_dim / 2],
                           [x_m + x_dim / 2, y_m + y_dim / 2],
                           [x_m + x_dim / 2, y_m - y_dim / 2]], dtype=np.float32)

    # Assume that the other four corners are along the plane normal (flipped to non-negative z)
    corners = xyz_c + np.matmul(np.column_stack([pp1, pp2]), xy_corners.T).T
    nP = P[:3, 0].T / np.linalg.norm(P[:3, 0].T)
    if nP[2] < 0:
        nP = -nP
    # np.vstack instead of np.row_stack, which is deprecated since NumPy 2.0
    corners = np.vstack([corners, corners + z_dim * nP])

    print("Color: {}, plane: {}, direction: {}, dim: {}".format(color_name, plane_conf, direction_conf, dim_conf))
    total_conf = (plane_conf * direction_conf * dim_conf) ** (1/3)

    obj = {'type': obj_type, 'corners': corners, 'color': color_name, 'conf': total_conf}
    return obj


def find_object_ransac(xyz, color_name, prev_objects=None, debug=False):
    """ Returns the object dict using ransac-based methods for given points

    The dominant plane is found with find_first_plane (corners of previously detected cubes/cuboids are
    also offered as candidate planes). If fewer than 225 points lie on it the detection is rejected;
    otherwise the object is fitted with project_cuboid_onto_plane using all points of the object.

    :param xyz: n x 3 array XYZ pointcloud data of one object
    :param color_name: name of color for obj creation
    :param prev_objects: list of previously detected object dicts (only 'cube'/'cuboid' ones are used), or None
    :param debug: debug flag to display the pointcloud data (plane inliers green, rest red, corners blue)
    :return: object dict as returned by project_cuboid_onto_plane, or None if the plane has too few points
    """
    xyz_h = np.concatenate([xyz, np.ones([xyz.shape[0], 1], dtype=np.float32)], axis=1)

    # The default None used to be iterated directly and raised TypeError
    if prev_objects is None:
        prev_objects = []
    prev_corners = [prev['corners'] for prev in prev_objects if prev['type'] in ('cube', 'cuboid')]
    P1, r_xyz_h, t_xyz_h, plane_conf = find_first_plane(xyz_h, prev_corners)

    # Reject planes supported by too few points
    if t_xyz_h.shape[0] < 225:
        return None

    obj = project_cuboid_onto_plane(P1, xyz_h, plane_conf, color_name)

    if debug:
        pts = np.concatenate([t_xyz_h[:, :3], r_xyz_h[:, :3], obj['corners']], axis=0)
        colors = np.concatenate([np.repeat([[0, 1, 0]], len(t_xyz_h), axis=0),
                                 np.repeat([[1, 0, 0]], len(r_xyz_h), axis=0),
                                 np.repeat([[0, 0, 1]], 8, axis=0)], axis=0)
        viewer = pptk.viewer(pts[:, :3])
        viewer.attributes(colors)
        viewer.set(point_size=0.0005)
        viewer.wait()
        viewer.close()

    return obj


def project_grasper_onto_plane(P, pts):
    r""" Locates four reference points on the two grasper fingers, given the plane of the grasper face

    Intended layout (as originally documented):
        ____       _____
       \    |     |    /
        \   |1   3|   /
         \  |     |  /
          \_|0   2|_/

    The in-plane points are split into two clusters (the fingers) with 2-means; the LDA discriminant of the
    clusters gives the x axis across them and y is perpendicular to it, flipped so that the median y is
    non-negative. "Bottom" is the minimum y and "top" is bottom + 0.010 (10 mm if coordinates are metres).
    x is the 95th percentile of the left cluster and the 5th percentile of the right one, i.e. the sides
    of the fingers facing each other.

    Rows of the returned array, as built by the code:
        0: (left finger, top), 1: (left finger, bottom), 2: (right finger, top), 3: (right finger, bottom)
    NOTE(review): the diagram above puts 0 and 2 at the bottom of the drawing; confirm which convention
    callers rely on before changing either.

    :param P: 4 x 1 array with plane spec such that a*x + b*y + c*z + d = 0 (as returned by find_first_plane)
    :param pts: n x 4 array of 3d points in homogeneous coords (projected orthogonally onto the plane)
    :return: 4 x 3 array with the XYZ points described above
    """
    # Project the mean of the points onto the plane; it is the origin of the in-plane coordinates
    xyz_c = np.mean(pts, axis=0)
    dist = np.matmul(xyz_c, P)
    Pn = np.divide(P[:3], np.sum(np.power(P[:3], 2), axis=0))
    xyz_c = xyz_c[:3] - dist * Pn.T

    # Get two perpendicular unit vectors in the plane as a base for a new coord system.
    # (c, 0, -a) is orthogonal to the normal (a, b, c) but vanishes when the normal is parallel to the
    # y axis; the x axis lies in the plane in that case.
    p1 = np.array([P[2, 0], 0, -P[0, 0]], dtype=np.float32)
    if not np.any(p1):
        p1 = np.array([1, 0, 0], dtype=np.float32)
    p1 = p1 / np.linalg.norm(p1)
    p2 = np.cross(p1, P[:3, 0].T)
    p2 = p2 / np.linalg.norm(p2)

    x = np.sum(p1 * (pts[:, :3] - xyz_c), axis=1)
    y = np.sum(p2 * (pts[:, :3] - xyz_c), axis=1)

    xy = np.column_stack([x, y])

    # Split the points into the two fingers; the LDA coefficient is the direction separating them
    labels = KMeans(n_clusters=2).fit_predict(xy)
    lda = LDA()
    lda.fit(xy, labels)
    w = lda.coef_

    # Get two new perpendicular vectors: pp1 across the fingers, pp2 perpendicular to it in the plane
    pp1 = w[0, 0] * p1 + w[0, 1] * p2
    pp1 = pp1 / np.linalg.norm(pp1)
    pp2 = np.cross(pp1, P[:3, 0].T)
    pp2 = pp2 / np.linalg.norm(pp2)

    # Transform points to the final 2d coordinate system
    x = np.sum(pp1 * (pts[:, :3] - xyz_c), axis=1)
    y = np.sum(pp2 * (pts[:, :3] - xyz_c), axis=1)

    # make sure that the grasper is not flipped: most points must have y >= 0
    if np.median(y) < 0:
        pp2 = np.cross(-pp1, P[:3, 0].T)
        pp2 = pp2 / np.linalg.norm(pp2)
        y = np.sum(pp2 * (pts[:, :3] - xyz_c), axis=1)

    x_0_min = np.percentile(x[labels == 0], 5)
    x_1_min = np.percentile(x[labels == 1], 5)

    # ensure we know which cluster is which
    if x_0_min < x_1_min:
        label_left = 0
        label_right = 1
    else:
        label_left = 1
        label_right = 0

    y_bottom = np.min(y)
    y_top = y_bottom + 0.010

    x_left = np.percentile(x[labels == label_left], 95)
    x_right = np.percentile(x[labels == label_right], 5)

    points = np.array([[x_left, y_top], [x_left, y_bottom], [x_right, y_top], [x_right, y_bottom]], dtype=np.float32)

    # reproject to 3D
    points = xyz_c + np.matmul(np.column_stack([pp1, pp2]), points.T).T
    return points


def find_grasper(xyz, prev_points=None, debug=False):
    r""" Returns the grasper object found in the given points, or None on failure

    Intended layout of the returned points:
        ____       _____
       \    |     |    /
        \   |1   3|   /
         \  |     |  /
          \_|0   2|_/

    Top points are 0.010 (10 mm if metres) above the bottom ones; the bottom ones should correspond to the
    bottom of the grasper.
    NOTE(review): project_grasper_onto_plane builds its rows as (left, top), (left, bottom), (right, top),
    (right, bottom), i.e. 0 and 2 are its "top" points — confirm against the drawing above.

    :param xyz: n x 3 array XYZ pointcloud data of the grasper
    :param prev_points: array of (at least 3) previously detected 3d grasper points, used as an extra
                        RANSAC plane candidate, or None
    :param debug: debug flag to display the pointcloud data (plane inliers green, the rest red)
    :return: object dict {'type': 'grasper', 'points': 4 x 3 array of XYZ points}, or None if fewer than
             4 points lie on the fitted plane
    """
    xyz_h = np.concatenate([xyz, np.ones([xyz.shape[0], 1], dtype=np.float32)], axis=1)
    # TODO: consider using PCA instead of RANSAC for the plane fit

    prev_corner_list = [] if prev_points is None else [prev_points]
    P1, r_xyz_h, t_xyz_h, plane_conf = find_first_plane(xyz_h, prev_corner_list)

    # Too few inliers to trust the plane
    if t_xyz_h.shape[0] < 4:
        return None

    out = project_grasper_onto_plane(P1, xyz_h)
    obj = {'type': 'grasper', 'points': out}

    if debug:
        pts = np.concatenate([t_xyz_h[:, :3], r_xyz_h[:, :3]], axis=0)
        colors = np.concatenate([np.repeat([[0, 1, 0]], len(t_xyz_h), axis=0),
                                 np.repeat([[1, 0, 0]], len(r_xyz_h), axis=0)], axis=0)
        viewer = pptk.viewer(pts[:, :3])
        viewer.attributes(colors)
        viewer.set(point_size=0.0005)
        viewer.wait()
        # close the viewer like every other debug path does, instead of leaking it
        viewer.close()

    return obj
